{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attention Saccades\n",
    "\n",
    "In the paper, we show that we can attend over a scene graph and localize each entity in a sequence. This notebook allows you to visualize some user specified scene graphs.\n",
    "\n",
    "Note that this notebook is not commented and might require some changes to run. Feel free to update it and send us a pull request."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iterator import SmartIterator\n",
    "from utils.visualization_utils import get_att_map, objdict, get_dict, add_attention, add_bboxes, get_bbox_from_heatmap, add_bbox_to_image\n",
    "from keras.models import load_model\n",
    "from models import ReferringRelationshipsModel\n",
    "from keras.utils import to_categorical\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "from keras.models import Model\n",
    "import json\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import h5py\n",
    "from keras.models import Model\n",
    "import keras.backend as K\n",
    "from keras.layers import Dense, Flatten, UpSampling2D, Input\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "from scipy.misc import imresize\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 34})\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###################\n",
    "img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'\n",
    "###################\n",
    "annotations_file = \"data/VRD/annotations_test.json\"\n",
    "vocab_dir = os.path.join('data/VRD')\n",
    "model_checkpoint = \"pretrained/vrd.h5\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "annotations_test = json.load(open(annotations_file))\n",
    "predicate_dict, obj_subj_dict = get_dict(vocab_dir)\n",
    "image_ids = sorted(list(annotations_test.keys()))[:1000]\n",
    "params = objdict(json.load(open(os.path.join(os.path.dirname(model_checkpoint), \"args.json\"), \"r\")))\n",
    "params.cnn = 'resnet'\n",
    "params.discovery = False\n",
    "relationships_model = ReferringRelationshipsModel(params)\n",
    "test_generator = SmartIterator(params.test_data_dir, params)\n",
    "images = test_generator.get_image_dataset()\n",
    "print(' | '.join(obj_subj_dict))\n",
    "print('')\n",
    "print(' | '.join(predicate_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = relationships_model.build_model()\n",
    "model.load_weights(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_im = Input(shape=(params.input_dim, params.input_dim, 3))\n",
    "input_pred = Input(shape=(params.num_predicates,))\n",
    "input_obj = Input(shape=(1,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image model that returns image feature maps\n",
    "im_output = model.get_layer(\"conv2d_1\").output\n",
    "image_model = Model(inputs=model.inputs, outputs=im_output)\n",
    "\n",
    "# Embedding weights that returns object embeddings\n",
    "model_weights = h5py.File(model_checkpoint)\n",
    "embeddings = model_weights[\"embedding_1\"][\"embedding_1\"][\"embeddings:0\"][()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "convs = {}\n",
    "for i, predicate in enumerate(predicate_dict):\n",
    "    convs[predicate] = []\n",
    "    for j in range(params.nb_conv_att_map):\n",
    "        layer_name = \"conv{}-predicate{}\".format(j, i)\n",
    "        convs[predicate] += [model_weights[layer_name][layer_name][\"kernel:0\"][()]]\n",
    "        \n",
    "convs_T = []     \n",
    "upsampling_factor = params.input_dim / params.feat_map_dim\n",
    "k = int(np.log(upsampling_factor) / np.log(2))\n",
    "for i in range(k):\n",
    "    layer_name = \"subject-convT-{}\".format(i)\n",
    "    convs_T += [model_weights[layer_name][layer_name][\"kernel:0\"][()]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shift(att, convs, predicate):\n",
    "    att = K.constant(att)\n",
    "    for j in range(params.nb_conv_att_map):\n",
    "        kernel = convs[predicate][j]\n",
    "        att = K.conv2d(att, kernel, padding='same', data_format='channels_last')\n",
    "        att = K.relu(att)\n",
    "    att = att.eval()\n",
    "    shifted_att = np.tanh(att)\n",
    "    return att\n",
    "\n",
    "def get_att(obj_idx, embeddings, im_features):\n",
    "    obj_emb = embeddings[obj_idx,:].reshape((1, 1, 1, im_features.shape[-1]))\n",
    "    att = (im_features*obj_emb).sum(axis=3, keepdims=True)\n",
    "    #att = np.tanh(att)\n",
    "    att = (att>0)*att\n",
    "    return att\n",
    "\n",
    "def upsample(att, convs_transpose, k):\n",
    "    _, shape, _, _ = att.shape\n",
    "    att = K.constant(att)\n",
    "    for i in range(k):\n",
    "        kernel = convs_T[i]\n",
    "        att = K.repeat_elements(att, 2, axis=1)\n",
    "        att = K.repeat_elements(att, 2, axis=2)\n",
    "        att = K.conv2d_transpose(att, kernel, padding='same', output_shape=(1, (2**(i+1)) * shape, (2**(i+1))*shape, 1))\n",
    "        att = K.relu(att)\n",
    "    att = K.tanh(att)\n",
    "    att = att.eval()\n",
    "    return att[0, :, :, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################\n",
    "image_index = np.random.randint(1000)\n",
    "print(image_index)\n",
    "#################\n",
    "img = Image.open(os.path.join(img_dir, image_ids[image_index]))\n",
    "img = img.resize((params.input_dim, params.input_dim))\n",
    "plt.figure(figsize=(5,5))\n",
    "plt.imshow(img)\n",
    "plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicate_id = np.zeros((1, params.num_predicates))\n",
    "obj_id = np.zeros((1, 1))\n",
    "im_features = image_model.predict([images[image_index:image_index+1], \n",
    "                                   np.zeros((1, 1)), \n",
    "                                   np.zeros((1, params.num_predicates)), \n",
    "                                   np.zeros((1, 1))])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "objects = [\"plate\", \"table\", \"person\"]\n",
    "predicates = [\"on\", \"on the right of\"]\n",
    "nb_plots = 2 + 2 * len(predicates)\n",
    "att = get_att(obj_subj_dict.index(objects[0]), embeddings, im_features)\n",
    "fig, axes = plt.subplots(1, nb_plots, figsize=(20, 5))\n",
    "for ax in axes:\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "ax_counter = 0\n",
    "axes[ax_counter].imshow(img)\n",
    "axes[ax_counter].set_xlabel(\"input image\", {'fontsize': 18})\n",
    "ax_counter += 1\n",
    "axes[ax_counter].imshow(upsample(att, convs_T, k), interpolation='spline16')\n",
    "axes[ax_counter].set_xlabel(objects[0], {'fontsize': 18})\n",
    "for i in range(len(objects)-1):\n",
    "    ax_counter += 1\n",
    "    shifted_att = shift(att, convs, predicates[i])\n",
    "    axes[ax_counter].set_xlabel(predicates[i], {'fontsize': 18})\n",
    "    axes[ax_counter].imshow(upsample(att, convs_T, k), interpolation='spline16')\n",
    "    att = get_att(obj_subj_dict.index(objects[i+1]), embeddings, im_features*shifted_att)\n",
    "    ax_counter += 1\n",
    "    axes[ax_counter].imshow(upsample(att, convs_T, k), interpolation='spline16')\n",
    "    axes[ax_counter].set_xlabel(objects[i+1], {'fontsize': 18})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize saccades\n",
    "\n",
    "The cell below is attempting to attend over 'plate -> on -> table -> on the right of -> person'. We encode the entities in the objects list and the predicates in the predicates list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#############################\n",
    "objects = [\"plate\", \"table\", \"person\"]\n",
    "predicates = [\"on\", \"on the right of\"]\n",
    "threshold = [0.5, 0.3, 0.2] \n",
    "#############################\n",
    "ncols = 2*(len(objects) + 1)\n",
    "nrows = 4\n",
    "fig = plt.figure(figsize=(14, 6))\n",
    "\n",
    "ax = plt.subplot2grid((nrows, ncols), (1, 0), colspan=2, rowspan=2)\n",
    "ax.imshow(img)\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "att = img\n",
    "features = im_features\n",
    "for i in range(len(objects)):\n",
    "    att = get_att(obj_subj_dict.index(objects[i]), embeddings, features)\n",
    "    up_att = upsample(att, convs_T, k)\n",
    "    ax = plt.subplot2grid((nrows, ncols), (0, 2*i+2), colspan=2, rowspan=2)\n",
    "    ax.imshow(up_att, interpolation='spline16')\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    \n",
    "    bbox = get_bbox_from_heatmap(up_att, threshold=threshold[i])\n",
    "    bboxed_image = add_bbox_to_image(img, bbox, color='red', width=3)\n",
    "    ax = plt.subplot2grid((nrows, ncols), (2, 2*i+2), colspan=2, rowspan=2)\n",
    "    ax.imshow(bboxed_image)\n",
    "    axes[ax_counter].set_xlabel(objects[i], {'fontsize': 18})\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    \n",
    "    if i >= len(predicates):\n",
    "        break\n",
    "    att = shift(att, convs, predicates[i])\n",
    "    features = features*shifted_att\n",
    "    \n",
    "plt.tight_layout(pad=0.1, w_pad=-1, h_pad=-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
